{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For Colab only!\n",
    "# Requests TensorFlow 2.x through a Colab-specific line magic. The failure\n",
    "# raised on runtimes that lack the magic is deliberately swallowed below,\n",
    "# so this cell is a no-op outside Colab.\n",
    "\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n",
      "WARNING:tensorflow:From <ipython-input-3-2e82a26757ae>:2: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.list_physical_devices('GPU')` instead.\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(tf.__version__)\n",
    "# tf.test.is_gpu_available() is deprecated; listing the physical GPU devices\n",
    "# is the replacement recommended by TensorFlow's own deprecation warning.\n",
    "# The list is turned into a bool so the cell still prints True/False.\n",
    "print(len(tf.config.list_physical_devices('GPU')) > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Report the installed PyTorch version and whether CUDA can be used, one per line.\n",
    "print(torch.__version__, torch.cuda.is_available(), sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Loading\n",
    "MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch_size=200\n",
    "learning_rate=0.01\n",
    "epochs=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "ds_train = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "\n",
    "\n",
    "def preprocess(x, y):\n",
    "    \"\"\"Prepare one (image, label) pair for training.\n",
    "\n",
    "    The image is cast to float32, divided by 255 and shifted by -0.1307.\n",
    "    The label is cast to int32 and kept as a class index (no one-hot\n",
    "    encoding).\n",
    "    \"\"\"\n",
    "    pixels = tf.cast(x, tf.float32) / 255\n",
    "    labels = tf.cast(y, tf.int32)\n",
    "    return pixels - 0.1307, labels\n",
    "\n",
    "\n",
    "ds_train = (\n",
    "    ds_train\n",
    "    .map(preprocess)\n",
    "    .shuffle(1000)\n",
    "    .batch(batch_size)\n",
    ")\n",
    "ds_test = (\n",
    "    ds_test\n",
    "    .map(preprocess)\n",
    "    .shuffle(1000)\n",
    "    .batch(batch_size)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Both loaders only convert images with ToTensor(); normalisation with\n",
    "# transforms.Normalize((0.1307,), (0.3081,)) is currently disabled.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=datasets.MNIST(\n",
    "        root='../data',\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform=transforms.Compose([transforms.ToTensor()]),\n",
    "    ),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    dataset=datasets.MNIST(\n",
    "        root='../data',\n",
    "        train=False,\n",
    "        transform=transforms.Compose([transforms.ToTensor()]),\n",
    "    ),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>\n",
      "(200, 28, 28) (200,)\n"
     ]
    }
   ],
   "source": [
    "# Pull a single batch from the TensorFlow test pipeline and show its shapes.\n",
    "print(type(ds_test))\n",
    "image, label = next(iter(ds_test))\n",
    "print(f\"{image.shape} {label.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.utils.data.dataloader.DataLoader'>\n",
      "torch.Size([200, 1, 28, 28]) torch.Size([200])\n"
     ]
    }
   ],
   "source": [
    "# Pull a single batch from the PyTorch training loader and show its shapes.\n",
    "print(type(train_loader))\n",
    "image, label = next(iter(train_loader))\n",
    "print(f\"{image.shape} {label.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization with TensorBoard and Visdom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import datetime, os\n",
    "# Every run writes its summaries to its own timestamped directory under logs/.\n",
    "current_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "log_dir = f'logs/{current_time}'\n",
    "summary_writer = tf.summary.create_file_writer(logdir=log_dir)\n",
    "\n",
    "# To view the logs, run in a terminal: tensorboard --logdir logs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "from visdom import Visdom\n",
    "\n",
    "# terminal run: python -m visdom.server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "cell_style": "center"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0 loss:2.342546224594116\n",
      "accuracy:  0.3911\n",
      "epoch:0, step:100 loss:0.275505930185318\n",
      "accuracy:  0.9315\n",
      "epoch:0, step:200 loss:0.13893592357635498\n",
      "accuracy:  0.9546\n",
      "epoch:1, step:0 loss:0.1728418469429016\n",
      "accuracy:  0.9511\n",
      "epoch:1, step:100 loss:0.1439828872680664\n",
      "accuracy:  0.9584\n",
      "epoch:1, step:200 loss:0.054427556693553925\n",
      "accuracy:  0.9664\n",
      "epoch:2, step:0 loss:0.14143946766853333\n",
      "accuracy:  0.9607\n",
      "epoch:2, step:100 loss:0.043385621160268784\n",
      "accuracy:  0.9668\n",
      "epoch:2, step:200 loss:0.08430337905883789\n",
      "accuracy:  0.9656\n",
      "epoch:3, step:0 loss:0.1252625435590744\n",
      "accuracy:  0.97\n",
      "epoch:3, step:100 loss:0.07537724822759628\n",
      "accuracy:  0.9692\n",
      "epoch:3, step:200 loss:0.09454736113548279\n",
      "accuracy:  0.9678\n",
      "epoch:4, step:0 loss:0.18631324172019958\n",
      "accuracy:  0.9657\n",
      "epoch:4, step:100 loss:0.044349975883960724\n",
      "accuracy:  0.9626\n",
      "epoch:4, step:200 loss:0.13011203706264496\n",
      "accuracy:  0.9638\n",
      "epoch:5, step:0 loss:0.06854582577943802\n",
      "accuracy:  0.9563\n",
      "epoch:5, step:100 loss:0.12148669362068176\n",
      "accuracy:  0.9689\n",
      "epoch:5, step:200 loss:0.028568662703037262\n",
      "accuracy:  0.9713\n",
      "epoch:6, step:0 loss:0.08852332085371017\n",
      "accuracy:  0.9581\n",
      "epoch:6, step:100 loss:0.031847141683101654\n",
      "accuracy:  0.9667\n",
      "epoch:6, step:200 loss:0.04816087707877159\n",
      "accuracy:  0.9702\n",
      "epoch:7, step:0 loss:0.021321790292859077\n",
      "accuracy:  0.9695\n",
      "epoch:7, step:100 loss:0.023279208689928055\n",
      "accuracy:  0.9715\n",
      "epoch:7, step:200 loss:0.06978157162666321\n",
      "accuracy:  0.9733\n",
      "epoch:8, step:0 loss:0.05921979993581772\n",
      "accuracy:  0.9689\n",
      "epoch:8, step:100 loss:0.10814261436462402\n",
      "accuracy:  0.9647\n",
      "epoch:8, step:200 loss:0.021313974633812904\n",
      "accuracy:  0.9746\n",
      "epoch:9, step:0 loss:0.0787103921175003\n",
      "accuracy:  0.9715\n",
      "epoch:9, step:100 loss:0.03530854359269142\n",
      "accuracy:  0.9734\n",
      "epoch:9, step:200 loss:0.16958735883235931\n",
      "accuracy:  0.9745\n"
     ]
    }
   ],
   "source": [
    "class FC_model(keras.Model):\n",
    "    \"\"\"MLP classifier: Dense(200) -> ReLU -> Dense(100) -> ReLU -> Dense(10).\n",
    "\n",
    "    The last layer has no activation, so the outputs are raw logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.model = keras.Sequential(\n",
    "            [layers.Dense(200),\n",
    "            layers.ReLU(),\n",
    "            layers.Dense(100),\n",
    "            layers.ReLU(),\n",
    "            layers.Dense(10)]\n",
    "            )\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return the logits for a batch of flattened images.\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "\n",
    "model = FC_model()\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "\n",
    "# Counts optimizer steps across all epochs; used as the step of every summary.\n",
    "globle_step = 0\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    for step, (x, y) in enumerate(ds_train):\n",
    "        x = tf.reshape(x, [-1, 28*28])\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x)\n",
    "\n",
    "            losses = tf.losses.sparse_categorical_crossentropy(y, logits, from_logits=True)\n",
    "            loss = tf.reduce_mean(losses)\n",
    "\n",
    "        # Logging is not part of the differentiated computation, so it lives outside the tape.\n",
    "        with summary_writer.as_default():\n",
    "            tf.summary.scalar(\"train_loss_each_step\", loss.numpy(), step=globle_step)\n",
    "\n",
    "        # Differentiate with respect to the trainable weights only.\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{} loss:{}\".\n",
    "                  format(epoch, step, loss.numpy()))\n",
    "\n",
    "            # Evaluate on the full test set.\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "            test_loss_sum = 0.0\n",
    "\n",
    "            for x_test, y_test in ds_test:\n",
    "                x_test = tf.reshape(x_test, [-1, 28*28])\n",
    "                logits_test = model(x_test)\n",
    "\n",
    "                loss_test = tf.losses.sparse_categorical_crossentropy(y_test, logits_test, from_logits=True)\n",
    "                loss_test = tf.reduce_mean(loss_test)\n",
    "                # Weight each batch mean by its size so the final value is a per-sample average.\n",
    "                test_loss_sum += float(loss_test) * x_test.shape[0]\n",
    "\n",
    "                y_pred = tf.argmax(logits_test, axis=1)\n",
    "                y_pred = tf.cast(y_pred, tf.int32)\n",
    "                correct = tf.cast((y_pred == y_test), tf.int32)\n",
    "                correct = tf.reduce_sum(correct)\n",
    "\n",
    "                total_correct += int(correct)\n",
    "                total_num += x_test.shape[0]\n",
    "\n",
    "            accuracy = total_correct/total_num\n",
    "            # Mean over the whole test set rather than the loss of the last batch only.\n",
    "            mean_loss_test = test_loss_sum/total_num\n",
    "            print('accuracy: ', accuracy)\n",
    "\n",
    "            # Take 25 images from the last test batch and undo the -0.1307 shift applied in\n",
    "            # preprocess(). tf.summary.image clips float data to [0, 1], so the pixels must not\n",
    "            # be rescaled to 0-255 (that would turn every non-zero pixel white).\n",
    "            x_images = x_test[:25] + 0.1307\n",
    "            x_images = tf.reshape(x_images, [-1,28,28,1])\n",
    "\n",
    "            with summary_writer.as_default():\n",
    "                tf.summary.scalar(\"accuracy\", accuracy, step=globle_step)\n",
    "                tf.summary.scalar(\"train_loss\", loss.numpy(), step=globle_step)\n",
    "                tf.summary.scalar(\"test_loss\", mean_loss_test, step=globle_step)\n",
    "\n",
    "                tf.summary.image(\"images\", x_images, step=globle_step, max_outputs=25)\n",
    "\n",
    "        globle_step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_style": "center"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting up a new session...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0, loss:2.312376022338867\n",
      "accuracy:  0.20260000228881836\n",
      "epoch:0, step:100, loss:0.24057549238204956\n",
      "accuracy:  0.9429000020027161\n",
      "epoch:0, step:200, loss:0.17944090068340302\n",
      "accuracy:  0.9506999850273132\n",
      "epoch:1, step:0, loss:0.04460395500063896\n",
      "accuracy:  0.9589999914169312\n",
      "epoch:1, step:100, loss:0.10701652616262436\n",
      "accuracy:  0.9623000025749207\n",
      "epoch:1, step:200, loss:0.10582515597343445\n",
      "accuracy:  0.9624999761581421\n",
      "epoch:2, step:0, loss:0.13842931389808655\n",
      "accuracy:  0.9690999984741211\n",
      "epoch:2, step:100, loss:0.13088180124759674\n",
      "accuracy:  0.9680999517440796\n",
      "epoch:2, step:200, loss:0.10829789191484451\n",
      "accuracy:  0.9501000046730042\n",
      "epoch:3, step:0, loss:0.13796570897102356\n",
      "accuracy:  0.9673999547958374\n",
      "epoch:3, step:100, loss:0.08032510429620743\n",
      "accuracy:  0.9721999764442444\n",
      "epoch:3, step:200, loss:0.030323320999741554\n",
      "accuracy:  0.9740999937057495\n",
      "epoch:4, step:0, loss:0.05643593892455101\n",
      "accuracy:  0.9724999666213989\n",
      "epoch:4, step:100, loss:0.06288091838359833\n",
      "accuracy:  0.9681999683380127\n",
      "epoch:4, step:200, loss:0.06826059520244598\n",
      "accuracy:  0.9718999862670898\n",
      "epoch:5, step:0, loss:0.03170830383896828\n",
      "accuracy:  0.9761999845504761\n",
      "epoch:5, step:100, loss:0.037314578890800476\n",
      "accuracy:  0.9745000004768372\n",
      "epoch:5, step:200, loss:0.01824205555021763\n",
      "accuracy:  0.9728999733924866\n",
      "epoch:6, step:0, loss:0.06442216038703918\n",
      "accuracy:  0.9728999733924866\n",
      "epoch:6, step:100, loss:0.06548592448234558\n",
      "accuracy:  0.9648999571800232\n",
      "epoch:6, step:200, loss:0.0713154524564743\n",
      "accuracy:  0.9680999517440796\n",
      "epoch:7, step:0, loss:0.013396530412137508\n",
      "accuracy:  0.9721999764442444\n",
      "epoch:7, step:100, loss:0.0777147114276886\n",
      "accuracy:  0.9739999771118164\n",
      "epoch:7, step:200, loss:0.048254676163196564\n",
      "accuracy:  0.9716999530792236\n",
      "epoch:8, step:0, loss:0.02983185090124607\n",
      "accuracy:  0.964199960231781\n",
      "epoch:8, step:100, loss:0.03215639665722847\n"
     ]
    }
   ],
   "source": [
    "class FC_NN(nn.Module):\n",
    "    \"\"\"MLP classifier: Linear(784, 200) -> ReLU -> Linear(200, 100) -> ReLU -> Linear(100, 10).\n",
    "\n",
    "    The last layer has no activation, so the outputs are raw logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(28*28, 200),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(200, 100),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(100,10)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits for a batch of flattened images.\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "\n",
    "# Use the first GPU when one is available; otherwise fall back to the CPU instead of failing.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "network = FC_NN().to(device)\n",
    "optimizer = torch.optim.Adam(network.parameters(),\n",
    "                            lr=learning_rate)\n",
    "criteon = torch.nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "# Requires a running Visdom server (python -m visdom.server).\n",
    "vis = Visdom()\n",
    "\n",
    "# Counts optimizer steps across all epochs; used as the x value of every plot.\n",
    "globle_step = 0\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    for step, (x, y) in enumerate(train_loader):\n",
    "        x = x.reshape(-1,28*28)\n",
    "\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        logits = network(x)\n",
    "        loss = criteon(logits, y)\n",
    "\n",
    "        vis.line([loss.item()],[globle_step],win='loss_each_step',\n",
    "                 update='append',opts=dict(title='train_loss_each_step') )\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{}, loss:{}\".\n",
    "                  format(epoch, step, loss.item()))\n",
    "\n",
    "            # Evaluate on the full test set.\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "            test_loss_sum = 0.0\n",
    "\n",
    "            # No gradients are needed for evaluation, so skip building the autograd graph.\n",
    "            with torch.no_grad():\n",
    "                for x_test, y_test in test_loader:\n",
    "                    x_test = x_test.reshape(-1,28*28)\n",
    "                    x_test, y_test = x_test.to(device), y_test.to(device)\n",
    "\n",
    "                    logits_test = network(x_test)\n",
    "                    loss_test = criteon(logits_test, y_test)\n",
    "                    # Weight each batch mean by its size so the final value is a per-sample average.\n",
    "                    test_loss_sum += loss_test.item() * x_test.shape[0]\n",
    "\n",
    "                    y_pred = torch.argmax(logits_test, dim = 1)\n",
    "                    correct = y_pred == y_test\n",
    "                    correct = correct.sum()\n",
    "\n",
    "                    total_correct += correct.item()\n",
    "                    total_num += x_test.shape[0]\n",
    "\n",
    "            # Plain Python division avoids the float32 rounding noise of a tensor division.\n",
    "            acc = total_correct/total_num\n",
    "            # Mean over the whole test set rather than the loss of the last batch only.\n",
    "            mean_loss_test = test_loss_sum/total_num\n",
    "            print(\"accuracy: \", acc)\n",
    "\n",
    "            vis.line([acc],[globle_step],\n",
    "                     win='acc', update='append',\n",
    "                    opts=dict(title = 'accuracy'))\n",
    "\n",
    "            vis.line([[loss.item(), mean_loss_test]],[globle_step],\n",
    "                             win='losses', update='append',\n",
    "                             opts=dict(title='losses',legend=['train_loss', 'test_loss'] ))\n",
    "\n",
    "            # Show the last test batch together with the predictions made for it.\n",
    "            vis.images(x_test.reshape(-1,1,28,28), win='images')\n",
    "            vis.text(str(y_pred.detach().cpu().numpy()), win = 'pred')\n",
    "\n",
    "        globle_step += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TensorFlow with Visdom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting up a new session...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0 loss:2.360940933227539\n",
      "accuracy:  0.3759\n",
      "epoch:0, step:100 loss:0.15554414689540863\n",
      "accuracy:  0.9399\n",
      "epoch:0, step:200 loss:0.09762094169855118\n",
      "accuracy:  0.9536\n",
      "epoch:1, step:0 loss:0.0583413802087307\n",
      "accuracy:  0.9534\n",
      "epoch:1, step:100 loss:0.11694678664207458\n",
      "accuracy:  0.9517\n",
      "epoch:1, step:200 loss:0.09646661579608917\n",
      "accuracy:  0.964\n",
      "epoch:2, step:0 loss:0.07271380722522736\n",
      "accuracy:  0.9623\n",
      "epoch:2, step:100 loss:0.07763439416885376\n",
      "accuracy:  0.9613\n",
      "epoch:2, step:200 loss:0.06338293850421906\n",
      "accuracy:  0.9652\n",
      "epoch:3, step:0 loss:0.09928105026483536\n",
      "accuracy:  0.9664\n",
      "epoch:3, step:100 loss:0.0777912363409996\n",
      "accuracy:  0.9688\n",
      "epoch:3, step:200 loss:0.09107643365859985\n",
      "accuracy:  0.968\n",
      "epoch:4, step:0 loss:0.14426881074905396\n",
      "accuracy:  0.962\n",
      "epoch:4, step:100 loss:0.08075625449419022\n",
      "accuracy:  0.9709\n",
      "epoch:4, step:200 loss:0.10163595527410507\n",
      "accuracy:  0.9682\n",
      "epoch:5, step:0 loss:0.07181905955076218\n",
      "accuracy:  0.9662\n",
      "epoch:5, step:100 loss:0.018445981666445732\n",
      "accuracy:  0.9739\n",
      "epoch:5, step:200 loss:0.10279613733291626\n",
      "accuracy:  0.9605\n",
      "epoch:6, step:0 loss:0.03103221394121647\n",
      "accuracy:  0.9663\n",
      "epoch:6, step:100 loss:0.02783619798719883\n",
      "accuracy:  0.9664\n",
      "epoch:6, step:200 loss:0.07177065312862396\n",
      "accuracy:  0.972\n",
      "epoch:7, step:0 loss:0.0233557540923357\n",
      "accuracy:  0.9665\n",
      "epoch:7, step:100 loss:0.0869794487953186\n",
      "accuracy:  0.9656\n",
      "epoch:7, step:200 loss:0.11168929934501648\n",
      "accuracy:  0.9693\n",
      "epoch:8, step:0 loss:0.16089022159576416\n",
      "accuracy:  0.9669\n",
      "epoch:8, step:100 loss:0.05894499644637108\n",
      "accuracy:  0.9686\n",
      "epoch:8, step:200 loss:0.0439315028488636\n",
      "accuracy:  0.9694\n",
      "epoch:9, step:0 loss:0.24045337736606598\n",
      "accuracy:  0.9659\n",
      "epoch:9, step:100 loss:0.08383455872535706\n",
      "accuracy:  0.9755\n",
      "epoch:9, step:200 loss:0.056969426572322845\n",
      "accuracy:  0.9738\n"
     ]
    }
   ],
   "source": [
    "# Defined in this cell as well, so the cell does not depend on the TensorBoard example having run.\n",
    "class FC_model(keras.Model):\n",
    "    \"\"\"MLP classifier: Dense(200) -> ReLU -> Dense(100) -> ReLU -> Dense(10).\n",
    "\n",
    "    The last layer has no activation, so the outputs are raw logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.model = keras.Sequential(\n",
    "            [layers.Dense(200),\n",
    "            layers.ReLU(),\n",
    "            layers.Dense(100),\n",
    "            layers.ReLU(),\n",
    "            layers.Dense(10)]\n",
    "            )\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return the logits for a batch of flattened images.\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "\n",
    "model = FC_model()\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "\n",
    "# Requires a running Visdom server (python -m visdom.server).\n",
    "vis = Visdom()\n",
    "\n",
    "# Counts optimizer steps across all epochs; used as the x value of every plot.\n",
    "globle_step = 0\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    for step, (x, y) in enumerate(ds_train):\n",
    "        x = tf.reshape(x, [-1, 28*28])\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x)\n",
    "\n",
    "            losses = tf.losses.sparse_categorical_crossentropy(y, logits, from_logits=True)\n",
    "            loss = tf.reduce_mean(losses)\n",
    "\n",
    "        # Plotting is not part of the differentiated computation, so it lives outside the tape.\n",
    "        vis.line([loss.numpy()],[globle_step],win='loss_each_step',\n",
    "                 update='append',opts=dict(title='train_loss_each_step'))\n",
    "\n",
    "        # Differentiate with respect to the trainable weights only.\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{} loss:{}\".\n",
    "                  format(epoch, step, loss.numpy()))\n",
    "\n",
    "            # Evaluate on the full test set.\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "            test_loss_sum = 0.0\n",
    "\n",
    "            for x_test, y_test in ds_test:\n",
    "                x_test = tf.reshape(x_test, [-1, 28*28])\n",
    "                logits_test = model(x_test)\n",
    "\n",
    "                loss_test = tf.losses.sparse_categorical_crossentropy(y_test, logits_test, from_logits=True)\n",
    "                loss_test = tf.reduce_mean(loss_test)\n",
    "                # Weight each batch mean by its size so the final value is a per-sample average.\n",
    "                test_loss_sum += float(loss_test) * x_test.shape[0]\n",
    "\n",
    "                y_pred = tf.argmax(logits_test, axis=1)\n",
    "                y_pred = tf.cast(y_pred, tf.int32)\n",
    "                correct = tf.cast((y_pred == y_test), tf.int32)\n",
    "                correct = tf.reduce_sum(correct)\n",
    "\n",
    "                total_correct += int(correct)\n",
    "                total_num += x_test.shape[0]\n",
    "\n",
    "            accuracy = total_correct/total_num\n",
    "            # Mean over the whole test set rather than the loss of the last batch only.\n",
    "            mean_loss_test = test_loss_sum/total_num\n",
    "            print('accuracy: ', accuracy)\n",
    "\n",
    "            vis.line([accuracy], [globle_step], win='acc', update='append',\n",
    "                    opts=dict(title='accuracy'))\n",
    "\n",
    "            vis.line([[loss.numpy(), mean_loss_test]],[globle_step], win='losses',\n",
    "                     opts=dict(title='loss:train vs test', legend=['train','test']),\n",
    "                     update='append')\n",
    "\n",
    "            # Undo the -0.1307 shift from preprocess() and rescale to 0-255 before\n",
    "            # sending the last test batch to Visdom in (N, 1, 28, 28) layout.\n",
    "            x_images = (x_test+0.1307)*255\n",
    "            x_images = tf.reshape(x_images, [-1,1,28,28])\n",
    "            vis.images(x_images.numpy(), win='images')\n",
    "\n",
    "            # Predictions for the same (last) test batch.\n",
    "            vis.text(str(y_pred.numpy()),win='pred')\n",
    "\n",
    "        globle_step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
